import logging
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Literal, TypeVar, cast

import gymnasium as gym
import numpy as np
import torch

from tianshou.algorithm import Algorithm
from tianshou.algorithm.algorithm_base import (
    OnPolicyAlgorithm,
    Policy,
    TrainingStats,
)
from tianshou.algorithm.optim import OptimizerFactory
from tianshou.data import (
    Batch,
    ReplayBuffer,
    SequenceSummaryStats,
    to_torch,
    to_torch_as,
)
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import (
    BatchWithReturnsProtocol,
    DistBatchProtocol,
    ObsBatchProtocol,
    RolloutBatchProtocol,
)
from tianshou.utils import RunningMeanStd
from tianshou.utils.net.common import (
    AbstractContinuousActorProbabilistic,
    AbstractDiscreteActor,
    ActionReprNet,
)
from tianshou.utils.net.discrete import dist_fn_categorical_from_logits

log = logging.getLogger(__name__)


# Dimension Naming Convention
# B - Batch Size
# A - Action
# D - Dist input (usually 2, loc and scale)
# H - Dimension of hidden, can be None

TDistFnContinuous = Callable[
    [tuple[torch.Tensor, torch.Tensor]],
    torch.distributions.Distribution,
]
TDistFnDiscrete = Callable[[torch.Tensor], torch.distributions.Distribution]

TDistFnDiscrOrCont = TDistFnContinuous | TDistFnDiscrete


@dataclass(kw_only=True)
class LossSequenceTrainingStats(TrainingStats):
    """Training statistics carrying summary statistics over a sequence of loss values."""

    # summary (e.g. over all minibatch updates of one training step) of the recorded losses
    loss: SequenceSummaryStats


@dataclass(kw_only=True)
class SimpleLossTrainingStats(TrainingStats):
    """Training statistics carrying a single scalar loss value."""

    # the scalar loss value
    loss: float


class ProbabilisticActorPolicy(Policy):
    """
    A policy that outputs (representations of) probability distributions from which
    actions can be sampled.
    """

    def __init__(
        self,
        *,
        actor: AbstractContinuousActorProbabilistic | AbstractDiscreteActor | ActionReprNet,
        dist_fn: TDistFnDiscrOrCont,
        deterministic_eval: bool = False,
        action_space: gym.Space,
        observation_space: gym.Space | None = None,
        action_scaling: bool = True,
        action_bound_method: Literal["clip", "tanh"] | None = "clip",
    ) -> None:
        """
        :param actor: the actor network following the rules:
            If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
            If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
        :param dist_fn: the function/type which creates a distribution from the actor output,
            i.e. it maps the tensor(s) generated by the actor to a torch distribution.
            For continuous action spaces, the output is typically a pair of tensors
            (mean, std) and the distribution is a Gaussian distribution.
            For discrete action spaces, the output is typically a tensor of unnormalized
            log probabilities ("logits" in PyTorch terminology) or a tensor of probabilities
            which can serve as the parameters of a Categorical distribution.
            Note that if the actor uses softmax activation in its final layer, it will produce
            probabilities, whereas if it uses no activation, it can be considered as producing
            "logits".
            As a user, you are responsible for ensuring that the distribution
            is compatible with the output of the actor model and the action space.
        :param deterministic_eval: flag indicating whether the policy should use deterministic
            actions (using the mode of the action distribution) instead of stochastic ones
            (using random sampling) during evaluation.
            When enabled, the policy will always select the most probable action according to
            the learned distribution during evaluation phases, while still using stochastic
            sampling during training. This creates a clear distinction between exploration
            (training) and exploitation (evaluation) behaviors.
            Deterministic actions are generally preferred for final deployment and reproducible
            evaluation as they provide consistent behavior, reduce variance in performance
            metrics, and are more interpretable for human observers.
            Note that this parameter only affects behavior when the policy is not within a
            training step. When collecting rollouts for training, actions remain stochastic
            regardless of this setting to maintain proper exploration behaviour.
        :param action_space: the environment's action space.
        :param observation_space: the environment's observation space.
        :param action_scaling: flag indicating whether, for continuous action spaces, actions
            should be scaled from the standard neural network output range [-1, 1] to the
            environment's action space range [action_space.low, action_space.high].
            This applies to continuous action spaces only (gym.spaces.Box) and has no effect
            for discrete spaces.
            When enabled, policy outputs are expected to be in the normalized range [-1, 1]
            (after bounding), and are then linearly transformed to the actual required range.
            This improves neural network training stability, allows the same algorithm to work
            across environments with different action ranges, and standardizes exploration
            strategies.
            Should be disabled if the actor model already produces outputs in the correct range.
        :param action_bound_method: the method used for bounding actions in continuous action spaces
            to the range [-1, 1] before scaling them to the environment's action space (provided
            that `action_scaling` is enabled).
            This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None
            for discrete spaces.
            When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this
            range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly
            constrains outputs to [-1, 1] while preserving gradients.
            The choice of bounding method affects both training dynamics and exploration behavior.
            Clipping provides hard boundaries but may create plateau regions in the gradient
            landscape, while tanh provides smoother transitions but can compress sensitivity
            near the boundaries.
            Should be set to None if the actor model inherently produces bounded outputs.
            Typically used together with `action_scaling=True`.
        """
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            action_scaling=action_scaling,
            action_bound_method=action_bound_method,
        )
        if action_scaling:
            # Best-effort sanity check: only actors exposing a scalar-convertible `max_action`
            # can be checked. The `try` covers just the lookup/conversion, so it cannot swallow
            # the warning below (e.g. when warnings are escalated to errors) or
            # KeyboardInterrupt/SystemExit.
            max_action: float | None
            try:
                max_action = float(actor.max_action)
            except Exception:
                max_action = None
            if max_action is not None and np.isclose(max_action, 1.0):
                warnings.warn(
                    "action_scaling and action_bound_method are only intended "
                    "to deal with unbounded model action space, but found actor model "
                    f"bound action space with max_action={actor.max_action}. "
                    "Consider using unbounded=True option of the actor model, "
                    "or set action_scaling to False and action_bound_method to None.",
                )

        self.actor = actor
        self.dist_fn = dist_fn
        self._eps = 1e-8
        self.deterministic_eval = deterministic_eval

    def forward(
        self,
        batch: ObsBatchProtocol,
        state: dict | BatchProtocol | np.ndarray | None = None,
    ) -> DistBatchProtocol:
        """Compute action over the given batch data by applying the actor.

        Will sample from the dist_fn, if appropriate.
        Returns a new object representing the processed batch data
        (contrary to other methods that modify the input batch inplace).
        """
        action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
        # in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A
        # therefore action_dist_input_BD is equivalent to logits_BA
        # If continuous, dist_fn will typically map loc, scale to a distribution (usually a Gaussian)
        # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
        dist = self.dist_fn(action_dist_input_BD)

        # Use the distribution's mode only for deterministic evaluation outside of training steps;
        # otherwise sample stochastically.
        act_B = (
            dist.mode
            if self.deterministic_eval and not self.is_within_training_step
            else dist.sample()
        )
        # act is of dimension BA in continuous case and of dimension B in discrete
        result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
        return cast(DistBatchProtocol, result)


class DiscreteActorPolicy(ProbabilisticActorPolicy):
    """Probabilistic actor policy restricted to `gym.spaces.Discrete` action spaces.

    Action scaling and action bounding are always disabled for this policy.
    """

    def __init__(
        self,
        *,
        actor: AbstractDiscreteActor | ActionReprNet,
        dist_fn: TDistFnDiscrete = dist_fn_categorical_from_logits,
        deterministic_eval: bool = False,
        action_space: gym.Space,
        observation_space: gym.Space | None = None,
    ) -> None:
        """
        :param actor: the actor network, mapping observations to the distribution input
            (`s_B` -> `dist_input_BD`).
        :param dist_fn: callable/type turning the actor's output tensor into a torch
            distribution. The actor output is typically either unnormalized log probabilities
            ("logits", if the final layer has no activation) or probabilities (if the final
            layer applies softmax); both can parametrize a Categorical distribution.
            It is the caller's responsibility to pick a distribution that matches the actor's
            output and the action space.
        :param deterministic_eval: if True, the mode of the action distribution is used
            instead of a random sample whenever the policy is not within a training step
            (i.e. during evaluation). Rollouts collected for training always sample
            stochastically, regardless of this flag, to preserve exploration.
            Deterministic evaluation gives reproducible, lower-variance performance
            measurements and is typically what one wants for deployment.
        :param action_space: the environment's action space; must be `gym.spaces.Discrete`.
        :param observation_space: the environment's observation space.
        :raises ValueError: if `action_space` is not a `gym.spaces.Discrete` instance.
        """
        is_discrete = isinstance(action_space, gym.spaces.Discrete)
        if not is_discrete:
            raise ValueError(f"Action space must be an instance of Discrete; got {action_space}")
        # Scaling and bounding only apply to continuous (Box) action spaces, so both are
        # switched off unconditionally here.
        super().__init__(
            actor=actor,
            dist_fn=dist_fn,
            deterministic_eval=deterministic_eval,
            action_space=action_space,
            observation_space=observation_space,
            action_scaling=False,
            action_bound_method=None,
        )


# Type variable ranging over ProbabilisticActorPolicy and its subclasses.
TActorPolicy = TypeVar("TActorPolicy", bound=ProbabilisticActorPolicy)


class DiscountedReturnComputation:
    """Adds (optionally standardized) Monte Carlo discounted returns to rollout batches."""

    def __init__(
        self,
        gamma: float = 0.99,
        return_standardization: bool = False,
    ):
        """
        :param gamma: discount factor in [0, 1] applied to future rewards.
            Small values (near 0) make the agent short-sighted ("myopic"), favouring immediate
            rewards; values near 1 weight long-term rewards more heavily, which can help on
            tasks with delayed rewards but increases the variance of the training signal.
            Values between 0.9 and 0.99 are common choices.
        :param return_standardization: if True, returns are standardized by subtracting a
            running mean and dividing by a running standard deviation.
            Beware: this frequently hurts performance!
        """
        assert 0.0 <= gamma <= 1.0, "discount factor gamma should be in [0, 1]"
        self.gamma = gamma
        self.return_standardization = return_standardization
        # running statistics of the (unstandardized) returns seen so far
        self.ret_rms = RunningMeanStd()
        # added to the variance before taking the square root, to avoid division by zero
        self.eps = 1e-8

    def add_discounted_returns(
        self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray
    ) -> BatchWithReturnsProtocol:
        r"""Attach Monte Carlo discounted returns to each transition of `batch`.

        The result is stored in `batch.returns`; the input batch is mutated in place and
        also returned (cast to the returns-carrying protocol).

        .. math::
            G_t = \sum_{i=t}^T \gamma^{i-t}r_i

        with :math:`T` the terminal step and :math:`\gamma \in [0, 1]` the discount factor.

        :param batch: several episodes of data in sequential order. Finished episodes must
            end with a done flag; episodes still being collected are identified via
            buffer.unfinished_index().
        :param buffer: the replay buffer `batch` was drawn from.
        :param indices: positions of `batch` in `buffer`, i.e. `batch == buffer[indices]`.
        """
        # One value per transition, all equal to the running mean of returns; passed as
        # `v_s_` (presumably used to bootstrap episodes that have not finished — see
        # Algorithm.compute_episodic_return).
        next_values = np.full(indices.shape, self.ret_rms.mean)
        # gae_lambda=1.0 turns the episodic return computation into a Monte Carlo estimate.
        raw_returns, _ = Algorithm.compute_episodic_return(
            batch,
            buffer,
            indices,
            v_s_=next_values,
            gamma=self.gamma,
            gae_lambda=1.0,
        )

        if not self.return_standardization:
            batch.returns = raw_returns
            return cast(BatchWithReturnsProtocol, batch)

        # Standardize using the statistics from *before* this batch, then fold the batch in.
        running_std = np.sqrt(self.ret_rms.var + self.eps)
        batch.returns = (raw_returns - self.ret_rms.mean) / running_std
        self.ret_rms.update(raw_returns)
        return cast(BatchWithReturnsProtocol, batch)


class Reinforce(OnPolicyAlgorithm[ProbabilisticActorPolicy]):
    """Implementation of the REINFORCE (a.k.a. vanilla policy gradient) algorithm."""

    def __init__(
        self,
        *,
        policy: ProbabilisticActorPolicy,
        gamma: float = 0.99,
        return_standardization: bool = False,
        optim: OptimizerFactory,
    ) -> None:
        """
        :param policy: the probabilistic actor policy to be trained.
        :param optim: factory creating the optimizer for the policy's model.
        :param gamma: discount factor in [0, 1] applied to future rewards.
            Small values (near 0) make the agent short-sighted ("myopic"), favouring immediate
            rewards; values near 1 weight long-term rewards more heavily, which can help on
            tasks with delayed rewards but increases the variance of the training signal.
            Values between 0.9 and 0.99 are common choices.
        :param return_standardization: if True, returns are standardized by subtracting a
            running mean and dividing by a running standard deviation.
            This can hurt performance!
        """
        super().__init__(policy=policy)
        self.discounted_return_computation = DiscountedReturnComputation(
            gamma=gamma,
            return_standardization=return_standardization,
        )
        self.optim = self._create_optimizer(self.policy, optim)

    def _preprocess_batch(
        self,
        batch: RolloutBatchProtocol,
        buffer: ReplayBuffer,
        indices: np.ndarray,
    ) -> BatchWithReturnsProtocol:
        """Add discounted returns (field `returns`) to `batch`, which is modified in place."""
        add_returns = self.discounted_return_computation.add_discounted_returns
        return add_returns(batch, buffer, indices)

    # Requires BatchWithReturnsProtocol (narrower than the base signature), which breaks the
    # substitution principle; acceptable because this is private and _preprocess_batch
    # guarantees the batch carries `returns`.
    def _update_with_batch(  # type: ignore[override]
        self,
        batch: BatchWithReturnsProtocol,
        batch_size: int | None,
        repeat: int,
    ) -> LossSequenceTrainingStats:
        # A falsy batch_size (None or 0) is mapped to -1, presumably meaning "whole batch as a
        # single minibatch" for Batch.split.
        minibatch_size = batch_size or -1
        loss_values = []
        for _ in range(repeat):
            for minibatch in batch.split(minibatch_size, merge_last=True):
                out = self.policy(minibatch)
                taken_actions = to_torch_as(minibatch.act, out.act)
                returns = to_torch(minibatch.returns, torch.float, out.act.device)
                # Reshape so the trailing axis has length len(returns), allowing broadcasting
                # against the per-transition returns.
                log_probs = out.dist.log_prob(taken_actions).reshape(len(returns), -1).transpose(0, 1)
                # Policy-gradient surrogate: minimizing -E[log pi(a|s) * G] ascends expected return.
                pg_loss = -(log_probs * returns).mean()
                self.optim.step(pg_loss)
                loss_values.append(pg_loss.item())

        return LossSequenceTrainingStats(loss=SequenceSummaryStats.from_sequence(loss_values))
